import os
import pysam
from numpy import *
from pylab import *
from matplotlib import ticker


# Sequencing libraries, in plotting order.  The first 18 entries are the
# three single-replicate controls followed by five conditions in
# triplicate; the last 18 entries are six time points in triplicate.
libraries = (
    ("cel_r1", "neg_r1", "lip_r1")
    + tuple(
        "%s_r%d" % (group, replicate)
        for group in ("n91", "c91", "nkd", "myb", "gfi",
                      "t00", "t01", "t04", "t12", "t24", "t96")
        for replicate in (1, 2, 3)
    )
)

# Human-readable panel titles for the first 18 libraries above.
conditions = {
    "cel_r1": "Untreated cells",
    "lip_r1": "Lipofectamine only",
    "neg_r1": "Protocol negative control",
}
conditions.update(
    ("%s_r%d" % (group, replicate), "%s replicate %d" % (description, replicate))
    for group, description in (
        ("myb", "MYB CASPAR"),
        ("gfi", "GFI1 PROMPT"),
        ("nkd", "NC_A (MYB/GFI1)"),
        ("c91", "STAT6 PROMPT"),
        ("n91", "NC_A (STAT6)"),
    )
    for replicate in (1, 2, 3)
)


def kilo(x, pos):
    """Format an axis tick value, abbreviating thousands with a 'k' suffix.

    Intended as a matplotlib ``FuncFormatter`` callback, so *pos* (the tick
    index) is accepted but not used.  Values of 1000 and above are written
    in thousands (1000 -> '1k', 7500 -> '7.5k'); smaller values are written
    as-is (500 -> '500', 0.5 -> '0.5').
    """
    # '%g' rather than '%d' / integer division: tick locators can place
    # ticks at values such as 2500 or 0.5, which must not be truncated to
    # '2k' or '0' (that would produce wrong or duplicated labels).
    if x >= 1000:
        return '%gk' % (x / 1000)
    return '%g' % x

# Read the BAM file of each library, histogram the fragment lengths of its
# read pairs, and draw one panel per library: first the time-course
# libraries, then the control/knockdown libraries.  Each figure is saved as
# SVG and PNG in the current directory.

formatter = ticker.FuncFormatter(kilo)

directory = "/osc-fs_home/mdehoon/Data/CASPARs/MiSeq/Mapping"

# Constant added to every measured length before it is binned.  The size
# limits and x-range below are on this shifted scale; the x tick labels
# subtract it again so that the axis reads as RNA size.
offset = 128
# Shifted lengths in [size_min, size_max] count as "within" the selected
# size range; both limits are drawn as dashed red lines in each panel.
size_min = 200
size_max = 400
# Visible x-range of each panel.
xmin = 180
xmax = 420
# Shifted lengths at or above this value are left out of the histogram
# (they are still included in the "outside" count).
max_length = 1000
# Each figure is a grid of nrows x ncols panels, one per library.
nrows = 6
ncols = 3

timepoints = ("0 hours", "1 hour", "4 hours", "12 hours", "24 hours", "96 hours")


def _genomic_span(alignment1, alignment2):
    """Return the number of genomic bases spanned by a read pair.

    The mates must lie on opposite strands; the span runs from the start of
    the forward mate to the end of the reverse mate.
    """
    if alignment1.is_reverse:
        assert not alignment2.is_reverse
        start = alignment2.reference_start
        end = alignment1.reference_end
    else:
        assert alignment2.is_reverse
        start = alignment1.reference_start
        end = alignment2.reference_end
    assert start < end
    return end - start


def _count_lengths(path):
    """Tally the shifted lengths of the read pairs stored in a BAM file.

    Read 1 and read 2 of every pair must be consecutive records.  A pair
    without an XT tag is counted as unmapped and otherwise ignored.  For
    every other pair the length is its genomic span (XT == 'genome') or its
    XL tag, plus ``offset``.  A mapped pair contributes a weight of 1/NH; an
    unmapped pair that still carries a length contributes a weight of 1.

    Returns a tuple (counts, within, outside, unmapped):
        counts   -- array of size max_length; counts[k] is the summed weight
                    of pairs whose shifted length is k
        within   -- summed weight of pairs with size_min <= length <= size_max
        outside  -- summed weight of all other pairs that have a length
                    (each pair is counted exactly once)
        unmapped -- number of pairs without an XT tag

    Raises ValueError if the file ends in the middle of a pair, or if a pair
    carries a tag that is inconsistent with its mapping state.
    """
    counts = zeros(max_length)
    within = 0.0
    outside = 0.0
    unmapped = 0
    qnames = []  # names of unmapped pairs, checked for duplicates at the end
    print("Reading", path)
    with pysam.AlignmentFile(path) as alignments:
        for alignment1 in alignments:
            # The file object is its own iterator, so next() yields the mate.
            try:
                alignment2 = next(alignments)
            except StopIteration:
                raise ValueError("%s: read %s has no mate"
                                 % (path, alignment1.query_name)) from None
            assert alignment1.is_read1
            assert alignment2.is_read2
            try:
                target = alignment1.get_tag("XT")
            except KeyError:
                assert alignment1.is_unmapped
                assert alignment2.is_unmapped
                if alignment1.has_tag("XL"):
                    raise ValueError("%s: pair %s has an XL tag but no XT tag"
                                     % (path, alignment1.query_name))
                qnames.append(alignment1.query_name)
                unmapped += 1
                continue
            if target == 'genome':
                length = _genomic_span(alignment1, alignment2)
            else:
                length = alignment1.get_tag("XL")
            length += offset
            if alignment1.is_unmapped:
                assert alignment2.is_unmapped
                if alignment1.has_tag("NH"):
                    raise ValueError("%s: unmapped pair %s has an NH tag"
                                     % (path, alignment1.query_name))
                qnames.append(alignment1.query_name)
                weight = 1.0
            else:
                assert not alignment2.is_unmapped
                # NH is the number of reported alignments (SAM spec); dividing
                # by it keeps a multimapping pair from being over-counted,
                # assuming all NH alignments of the pair are in this file.
                weight = 1.0 / alignment1.get_tag("NH")
            if size_min <= length <= size_max:
                within += weight
            else:
                outside += weight
            # Guard the lower bound too: a negative index would silently
            # land in a bin counted from the end of the array.
            if 0 <= length < max_length:
                counts[length] += weight
    assert len(qnames) == len(set(qnames))
    return counts, within, outside, unmapped


def _plot_library(ax, counts, within, outside, unmapped, bottom_row):
    """Draw one library's length histogram and its summary counts into *ax*.

    Every nonzero bin becomes a black vertical line, the selected size
    limits are drawn as dashed red lines, and a text near the top left
    reads "within;", "outside;", "unmapped" (rounded to integers).  X tick
    labels are only drawn when *bottom_row* is true.
    """
    positions = flatnonzero(counts > 0)
    if positions.size:
        # One LineCollection instead of one Line2D per bin.
        ax.vlines(positions, 0, counts[positions], color='black',
                  label='Mapped sequences')
    ymin = 0
    ymax = ax.get_ylim()[1]
    ax.plot([size_min, size_min], [ymin, ymax], 'r--')
    ax.plot([size_max, size_max], [ymin, ymax], 'r--',
            label='Selected size limits')
    if bottom_row:
        sizes = (200, 300, 400)
        ax.set_xticks(sizes)
        ax.set_xticklabels([str(size - offset) for size in sizes], fontsize=8)
    else:
        ax.set_xticks([])
    ax.yaxis.set_major_formatter(formatter)
    ax.tick_params(axis='y', labelsize=8)
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    description = "%d;\n%d;\n%d" % (int(round(within)),
                                    int(round(outside)),
                                    unmapped)
    ax.text(205, 0.9 * ymax, description, fontsize=8, verticalalignment='top')


def _new_figure(ylabel_pad=None):
    """Create a 6x12 inch figure whose hidden full-size axes holds the shared labels.

    The background axes has no spines or visible tick labels, so only its x
    label and y label show behind the grid of panels added later.
    """
    fig = figure(figsize=(6, 12))
    ax = fig.add_subplot(111)
    for side in ('top', 'bottom', 'left', 'right'):
        ax.spines[side].set_color('none')
    ax.tick_params(labelcolor='w', top=False, bottom=False, left=False, right=False)
    ax.set_xlabel("RNA size [bp]")
    ax.set_ylabel("Read count", labelpad=ylabel_pad)
    return fig


def _finish_figure(fig, ax, basename):
    """Add the legend to *ax*, adjust the spacing and save *fig* as SVG and PNG."""
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    # Fixed entry order, independent of the order in which matplotlib
    # reports lines versus collections.
    order = [label for label in ('Selected size limits', 'Mapped sequences')
             if label in by_label]
    ax.legend([by_label[label] for label in order], order,
              bbox_to_anchor=(1.0, -0.2), fontsize=8)
    fig.subplots_adjust(bottom=0.08, top=0.97, left=0.15, right=0.98, wspace=0.3)
    for extension in ("svg", "png"):
        filename = "%s.%s" % (basename, extension)
        print("Saving figure as %s" % filename)
        fig.savefig(filename)


# Time-course figure: rows are time points, columns are replicates.
timecourse_libraries = libraries[18:36]
fig = _new_figure(ylabel_pad=27)
for i, library in enumerate(timecourse_libraries):
    row, column = divmod(i, ncols)
    counts, within, outside, unmapped = _count_lengths(
        os.path.join(directory, "%s.bam" % library))
    ax = fig.add_subplot(nrows, ncols, i + 1)
    _plot_library(ax, counts, within, outside, unmapped,
                  bottom_row=(row == nrows - 1))
    if column == 0:
        ax.set_ylabel(timepoints[row], fontsize=8)
    if row == 0:
        ax.set_title("Replicate %d" % (column + 1), fontsize=8)
_finish_figure(fig, ax, "figure_sequencing_miseq_timecourse")

# Knockdown figure: one panel per control or knockdown library.
knockdown_libraries = libraries[:18]
fig = _new_figure()
for i, library in enumerate(knockdown_libraries):
    row = i // ncols
    counts, within, outside, unmapped = _count_lengths(
        os.path.join(directory, "%s.bam" % library))
    ax = fig.add_subplot(nrows, ncols, i + 1)
    _plot_library(ax, counts, within, outside, unmapped,
                  bottom_row=(row == nrows - 1))
    ax.set_title(conditions[library], fontsize=8, pad=2)
_finish_figure(fig, ax, "figure_sequencing_miseq_knockdown")
